# -*- coding: utf-8 -*-
import numpy as np
import copy
import torch
from torch.autograd import Variable
import time
from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
from tqdm import tqdm
class Server(UnlearnBasicServer):
    """Federated-unlearning server that also maintains an auxiliary ``de_model``.

    ``run`` dispatches on ``self.stage``:
      * ``'Pretrain'``: ordinary federated training for ``num_rounds`` rounds,
        then a checkpoint.
      * ``'Unlearn'``: ``u_rounds`` unlearning rounds (see ``unlearn_iterate``),
        then ``self.stage`` is switched to ``'PT'`` and ``p_rounds``
        post-training rounds run, followed by a checkpoint.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        # Independent copy of the initial model. It is created before the base
        # initialiser runs, presumably because the base class may read it --
        # confirm.
        self.de_model = copy.deepcopy(model).to(device)
        super(Server, self).__init__(option, model, clients, data_loader, device)
        # Mixing weight between the global model and the aggregated de_models
        # (see unlearn_iterate).
        self.lam = 0.5
        # de_model updates happen only on odd rounds up to 70% of u_rounds.
        self.rde_round = int(0.7 * self.u_rounds)
        self.save_name += ' lam' + str(self.lam)

    def run(self):
        """Evaluate the initial model, then run the stage(s) selected by ``self.stage``."""
        self.current_rounds = 0
        test_metric = self.test_on_clients(dataflag='test', model=self.model)
        self.outFunc(t_metric=test_metric)

        # Two independent checks (not if/elif): each one re-reads self.stage
        # at the moment it is evaluated.
        if self.stage == 'Pretrain':
            self._run_pretrain_stage()

        if self.stage == 'Unlearn':
            self._run_unlearn_stage()
            self.stage = 'PT'
            self._run_post_train_stage()

    def _report_cuda_alloc(self, label, fn):
        """Call ``fn()`` and print the change in CUDA memory on device 0, in MiB."""
        before = torch.cuda.memory_allocated(0)
        fn()
        after = torch.cuda.memory_allocated(0)
        print(f'{label}_allocate{(after - before) / 1024 ** 2:.2f}')

    def _run_pretrain_stage(self):
        """Run ``num_rounds`` ordinary federated rounds, then save a checkpoint."""
        for rnd in tqdm(range(1, self.num_rounds + 1), desc='Pretraining Rounds'):
            self.current_rounds = rnd
            # Federated training step.
            self._report_cuda_alloc('iterate', self.iterate)
            # Save here, before the scheduler step and evaluation, so that the
            # global broadcast does not affect the saved result.
            self._report_cuda_alloc('save', self.pretrain_save)
            self.global_lr_scheduler(self.num_rounds)

            test_metric = self.test_on_clients(dataflag='test', model=self.model)
            self.outFunc(test_metric)
            self.save_log(self.out_log)
        self.save_ckp()

    def _run_unlearn_stage(self):
        """Run ``u_rounds`` unlearning rounds, evaluating on every even round."""
        for rnd in tqdm(range(1, self.u_rounds + 1), desc='Unlearning Rounds'):
            self.current_rounds = rnd
            self.unlearn_iterate()  # includes the global model update
            # NOTE(review): this stage passes num_rounds as the schedule length,
            # while post-training passes p_rounds; confirm u_rounds was not
            # intended here.
            self.global_lr_scheduler(self.num_rounds)
            if rnd % 2 == 0:
                test_metric = self.test_on_clients(dataflag='test', model=self.model)
                self.outFunc(test_metric)
                self.save_log(self.out_log)

    def _run_post_train_stage(self):
        """Run ``p_rounds`` post-training rounds, then save a checkpoint."""
        for rnd in tqdm(range(1, self.p_rounds + 1), desc='Post-training Rounds'):
            self.current_rounds = rnd
            self.pt_iterate()
            self.global_lr_scheduler(self.p_rounds)

            test_metric = self.test_on_clients(dataflag='test', model=self.model)
            self.outFunc(test_metric)
            self.save_log(self.out_log)
        self.save_ckp()

    def unlearn_iterate(self):
        """Run one unlearning round.

        On odd rounds with ``current_rounds <= rde_round``, the current
        ``de_model`` is sent to the unlearn clients only. The server then sets
        ``de_model = lam * model + (1 - lam) * aggregate(replies)`` and gives
        every unlearn client its own deep copy of it.

        Every other round is a regular federated round over all clients, and
        its aggregate replaces ``self.model``.
        """
        if self.current_rounds <= self.rde_round and self.current_rounds % 2 == 1:
            reply = self.communicate(self.unlearn_clients_id, self.de_model)
            de_models = reply['model']
            # NOTE(review): self.selected_clients is not updated in this branch.
            # If aggregate() weights by selected_clients, confirm that is
            # intended here.
            self.de_model = self.lam * self.model + (1 - self.lam) * self.aggregate(de_models)
            for uc in self.unlearn_clients:
                uc.de_model = copy.deepcopy(self.de_model)
            del de_models
        else:
            self.selected_clients = list(range(self.num_clients))
            reply = self.communicate(self.selected_clients)
            models = reply['model']
            self.model = self.aggregate(models)
            del models

class Client(UnlearnBasicClient):
    """Client whose ``train`` can also fit a supplied model to ``de_model``'s predictions."""

    def __init__(self, option, id, model=None):
        super(Client, self).__init__(option, id, model)
        # Teacher model used by the ``t_m`` path of ``train``. It is None until
        # assigned externally; the server in this file assigns a copy in
        # unlearn_iterate.
        self.de_model = None

    def train(self, t_m=None):
        """Train a model locally and return its average per-sample loss.

        Args:
            t_m: Model to optimise. If None, ``self.model`` is trained on the
                ground-truth labels. Otherwise ``t_m`` is trained to match the
                arg-max predictions of ``self.de_model``, and the ground-truth
                labels are ignored.

        Returns:
            The sum of (batch-mean loss * batch size) over all epochs, divided
            by ``datavol * epochs``. Returns 0.0 if that denominator is 0.
        """
        train_model = self.model if t_m is None else t_m
        use_teacher = t_m is not None
        # NOTE(review): the teacher path requires self.de_model to be assigned
        # (it is initialised to None) before train() is called with t_m.
        train_model.train()
        total_loss = 0.0
        optimizer = self.get_optimizer(train_model)
        for _ in range(self.epochs):
            for batch_data in self.train_data:
                batch_x = self.data_to_device(batch_data['image'], device=self.device)
                batch_y = self.data_to_device(batch_data['label'], device=self.device)
                train_model.zero_grad()
                # Forward through the model that is actually being optimised.
                # Only then do the optimizer's parameters (those of
                # train_model) receive gradients.
                outputs = train_model(batch_x)
                if use_teacher:
                    # The teacher only supplies integer class indices, so no
                    # gradient can flow through it. Skip building its
                    # autograd graph.
                    with torch.no_grad():
                        out_de = self.de_model(batch_x)
                    _, pred = torch.max(out_de, -1)
                    loss = self.criterion(outputs, pred)
                else:
                    loss = self.criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * len(batch_y)
        del optimizer
        num_samples = self.datavol * self.epochs
        # Guard against an empty local dataset or zero epochs.
        return total_loss / num_samples if num_samples else 0.0
